{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn  \n",
    "import torch.optim as optim  \n",
    "import random  \n",
    "# 定义多层感知机模型  \n",
    "class MLP(nn.Module):  \n",
    "    def __init__(self):  \n",
    "        super(MLP, self).__init__()  \n",
    "        self.layer1 = nn.Linear(1, 64)  # 输入层到隐藏层  (公鸡)\n",
    "        self.layer2 = nn.Linear(64, 64) # 隐藏层到隐藏层  \n",
    "        self.layer3 = nn.Linear(64, 2)  # 隐藏层到输出层（母鸡、小鸡的数量）  \n",
    "  \n",
    "    def forward(self, x):  \n",
    "        x = torch.relu(self.layer1(x))  \n",
    "        x = torch.relu(self.layer2(x))  \n",
    "        x = self.layer3(x)  # 输出层不使用激活函数，因为我们需要逼近连续值  \n",
    "        return x\n",
    "\n",
    "# 生成所有满足条件的（公鸡，母鸡，小鸡）组合  \n",
    "def generate_combinations():  \n",
    "    combinations = []  \n",
    "    for cock in range(21):  # 公鸡最多20只（5钱一只，最多100钱）  \n",
    "        for hen in range(34):  # 母鸡最多33只（3钱一只，最多99钱，留一只小鸡的可能性）  \n",
    "            chick = 100 - cock - hen  \n",
    "            if chick % 3 == 0 and 5 * cock + 3 * hen + chick / 3 == 100:  \n",
    "                combinations.append((cock, hen, chick))  \n",
    "    return combinations  \n",
    "\n",
    "# 从满足条件的组合中随机选择一部分作为训练数据  \n",
    "def select_random_samples(combinations, num_samples):  \n",
    "    random_samples = random.sample(combinations, num_samples)  \n",
    "    inputs = [sample[0] for sample in random_samples]  # 公鸡数量作为输入  \n",
    "    outputs = [(sample[1], sample[2]) for sample in random_samples]  # 母鸡和小鸡数量作为输出  \n",
    "    return torch.tensor(inputs, dtype=torch.float32), torch.tensor(outputs, dtype=torch.float32)  \n",
    "\n",
    "\n",
    "# 训练模型  \n",
    "def train_model(model, data, labels, epochs=10000, lr=0.01):\n",
    "    criterion = nn.MSELoss()  \n",
    "    optimizer = optim.SGD(model.parameters(), lr=lr)  \n",
    "    for epoch in range(epochs):  \n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(data)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()  \n",
    "\n",
    "        if (epoch+1) % 100 == 0:  \n",
    "            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data:tensor([ 4.,  0., 12.,  8.])\n",
      "labels:tensor([[18., 78.],\n",
      "        [25., 75.],\n",
      "        [ 4., 84.],\n",
      "        [11., 81.]])\n",
      "Epoch [100/1500], Loss: 409.4341\n",
      "Epoch [200/1500], Loss: 144.0071\n",
      "Epoch [300/1500], Loss: 84.8169\n",
      "Epoch [400/1500], Loss: 19.0063\n",
      "Epoch [500/1500], Loss: 13.4727\n",
      "Epoch [600/1500], Loss: 6.7440\n",
      "Epoch [700/1500], Loss: 2.2668\n",
      "Epoch [800/1500], Loss: 0.6730\n",
      "Epoch [900/1500], Loss: 0.2003\n",
      "Epoch [1000/1500], Loss: 0.0677\n",
      "Epoch [1100/1500], Loss: 0.0246\n",
      "Epoch [1200/1500], Loss: 0.0092\n",
      "Epoch [1300/1500], Loss: 0.0035\n",
      "Epoch [1400/1500], Loss: 0.0013\n",
      "Epoch [1500/1500], Loss: 0.0005\n"
     ]
    }
   ],
   "source": [
    "# 简单的任务，极致的享受（直接过拟合，无需泛化）\n",
    "model = MLP()  \n",
    "combinations = generate_combinations()  \n",
    "num_samples = len(combinations)  # 样本数量\n",
    "data, labels = select_random_samples(combinations, num_samples)\n",
    "print(f\"data:{data}\")\n",
    "print(f\"labels:{labels}\")\n",
    "train_model(model, data[:, None], labels, epochs=1500, lr=0.001)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "公鸡的数量: 0\n",
      "预测母鸡的数量: 25\n",
      "预测小鸡的数量: 75\n",
      "0×5+25×3+75÷3=100\n",
      "公鸡的数量: 4\n",
      "预测母鸡的数量: 18\n",
      "预测小鸡的数量: 78\n",
      "4×5+18×3+78÷3=100\n",
      "公鸡的数量: 8\n",
      "预测母鸡的数量: 11\n",
      "预测小鸡的数量: 81\n",
      "8×5+11×3+81÷3=100\n",
      "公鸡的数量: 12\n",
      "预测母鸡的数量: 4\n",
      "预测小鸡的数量: 84\n",
      "12×5+4×3+84÷3=100\n"
     ]
    }
   ],
   "source": [
    "# 批量测试数据\n",
    "test_data = torch.tensor([[0],[4],[8],[12]], dtype=torch.float32)\n",
    "for data in test_data:\n",
    "    # 测试无需梯度\n",
    "    with torch.no_grad():\n",
    "        # 测试模型\n",
    "        # test_input = torch.tensor([4], dtype=torch.float32)  # 输入\n",
    "        prediction = model(data)  \n",
    "        rounded_prediction = torch.round(prediction).int()  # 对输出进行四舍五入并取整\n",
    "        公鸡 = int(data.item())\n",
    "        母鸡 = rounded_prediction.numpy()[0]\n",
    "        小鸡 = rounded_prediction.numpy()[1]\n",
    "        print(\"公鸡的数量:\", 公鸡)\n",
    "        print(\"预测母鸡的数量:\", 母鸡)\n",
    "        print(\"预测小鸡的数量:\", 小鸡)\n",
    "        # 检测\n",
    "        print(f\"{公鸡}×5+{母鸡}×3+{小鸡}÷3={int(公鸡*5+母鸡*3+小鸡/3)}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Animism",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
